import os
from alphalens import (plotting, utils)
import alphalens.performance as perf
import alphalens
from scipy import stats
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
# Multiplier converting a decimal return into basis points (1 bp = 0.0001).
DECIMAL_TO_BPS = 10000

# Global plotting style, applied once at import time. Keep these seaborn
# calls before the font override below: seaborn's style may itself set
# 'font.sans-serif', and the later override must win.
sns.set_style("darkgrid")
sns.set_context("poster")

# Use the SimHei font so Chinese labels display correctly, and disable the
# Unicode minus so negative signs render correctly with that font.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

def _ensure_output_dirs(factor_name):
    """Create ``picture/<factor_name>/`` and ``csv/<factor_name>/`` if absent.

    ``os.makedirs(..., exist_ok=True)`` creates the parent directory as well
    and does not race with another process creating the same directory.
    """
    for root in ('picture', 'csv'):
        os.makedirs('{0}/{1}/'.format(root, factor_name), exist_ok=True)


def _save_current_figure(factor_name, file_name, close_figures):
    """Save the active figure as ``picture/<factor_name>/<file_name>``.

    When ``close_figures`` is true the figure is closed after saving, so that
    repeated runs do not keep every figure alive in pyplot's registry.
    """
    plt.savefig('picture/{0}/{1}'.format(factor_name, file_name),
                bbox_inches='tight')
    if close_figures:
        plt.close()


def _round_and_sort_columns(table):
    """Round every value to 3 decimals and order columns reverse-sorted."""
    table = table.apply(lambda x: x.round(3))
    return table.reindex(columns=sorted(table.columns, reverse=True))


def _build_returns_table(alpha_beta, mean_quant_rateret,
                         mean_ret_spread_quant):
    """Return alpha/beta rows plus top/bottom/spread mean returns in bps."""
    # DataFrame.append was removed in pandas 2.0; appending alpha_beta to an
    # empty frame is equivalent to starting from a copy of it.
    returns_table = alpha_beta.copy()
    returns_table.loc["Mean Period Wise Return Top Quantile (bps)"] = \
        mean_quant_rateret.iloc[-1] * DECIMAL_TO_BPS
    returns_table.loc["Mean Period Wise Return Bottom Quantile (bps)"] = \
        mean_quant_rateret.iloc[0] * DECIMAL_TO_BPS
    returns_table.loc["Mean Period Wise Spread (bps)"] = \
        mean_ret_spread_quant.mean() * DECIMAL_TO_BPS
    return _round_and_sort_columns(returns_table)


def _build_quantile_stats(factor_data):
    """Return min/max/mean/std/count of ``factor`` per quantile, plus count %."""
    # Aggregate only the 'factor' column. Aggregating the whole frame and
    # selecting 'factor' afterwards also aggregates any non-numeric column
    # (e.g. a string 'group' column), which raises TypeError on pandas >= 2.0.
    quantile_stats = factor_data.groupby('factor_quantile')['factor'] \
        .agg(['min', 'max', 'mean', 'std', 'count'])
    quantile_stats['count %'] = quantile_stats['count'] \
        / quantile_stats['count'].sum() * 100.
    return quantile_stats


def _build_ic_summary_table(ic):
    """Return IC mean/std/risk-adjusted IC/t-test/skew/kurtosis per period."""
    ic_summary_table = pd.DataFrame()
    ic_summary_table["IC Mean"] = ic.mean()
    ic_summary_table["IC Std."] = ic.std()
    ic_summary_table["Risk-Adjusted IC"] = \
        ic.mean() / ic.std()
    # One-sample t-test of each period's IC series against zero.
    t_stat, p_value = stats.ttest_1samp(ic, 0)
    ic_summary_table["t-stat(IC)"] = t_stat
    ic_summary_table["p-value(IC)"] = p_value
    ic_summary_table["IC Skew"] = stats.skew(ic)
    ic_summary_table["IC Kurtosis"] = stats.kurtosis(ic)
    # Transpose so periods become columns, matching the other summary tables.
    ic_summary_table = ic_summary_table.apply(lambda x: x.round(3)).T
    return ic_summary_table.reindex(
        columns=sorted(ic_summary_table.columns, reverse=True))


def _build_turnover_tables(quantile_turnover, autocorrelation):
    """Return (mean turnover per quantile/period, mean rank autocorrelation)."""
    turnover_table = pd.DataFrame()
    for period in sorted(quantile_turnover.keys()):
        # .items() replaces .iteritems(), which was removed in pandas 2.0.
        for quantile, p_data in quantile_turnover[period].items():
            turnover_table.loc["Quantile {} Mean Turnover ".format(quantile),
                               "{}".format(period)] = p_data.mean()
    auto_corr = pd.DataFrame()
    for period, p_data in autocorrelation.items():
        auto_corr.loc["Mean Factor Rank Autocorrelation",
                      "{}".format(period)] = p_data.mean()
    return (_round_and_sort_columns(turnover_table),
            _round_and_sort_columns(auto_corr))


def alpha_plot_all(factor_data, factor_name, long_short=True,
                            group_neutral=False,
                            by_group=False, close_figures=False):
    """Run the returns, IC and turnover analyses and save all output.

    Figures are written as JPEGs to ``picture/<factor_name>/``. Tables are
    written to ``csv/<factor_name>/quantile_stats.csv`` and
    ``csv/<factor_name>/all_data.csv``. The alphalens ``plot_*_table``
    helpers are also called, so the tables are displayed too.

    Parameters
    ----------
    factor_data : pd.DataFrame
        alphalens factor/forward-return frame with at least ``factor`` and
        ``factor_quantile`` columns (presumably from
        ``alphalens.utils.get_clean_factor_and_forward_returns`` - confirm).
    factor_name : str
        Sub-directory name used for every output file.
    long_short : bool
        Passed to alphalens as ``demeaned``.
    group_neutral : bool
        Passed to alphalens as ``group_adjust`` / ``group_neutral``.
    by_group : bool
        If True, plot mean IC per group; otherwise plot the monthly IC heatmap.
    close_figures : bool, optional
        Close each figure right after saving it. Defaults to False, which
        keeps figures open (e.g. for inline notebook display). Set it to True
        in batch runs so figures do not accumulate.
    """
    _ensure_output_dirs(factor_name)

    def save(file_name):
        _save_current_figure(factor_name, file_name, close_figures)

    # ---- returns analysis --------------------------------------------------
    factor_returns = perf.factor_returns(factor_data,
                                         long_short,
                                         group_neutral)
    mean_quant_ret, _ = perf.mean_return_by_quantile(
        factor_data,
        by_group=False,
        demeaned=long_short,
        group_adjust=group_neutral)
    mean_quant_rateret = mean_quant_ret.apply(
        alphalens.utils.rate_of_return, axis=0,
        base_period=mean_quant_ret.columns[0])

    mean_quant_ret_bydate, std_quant_daily = perf.mean_return_by_quantile(
        factor_data,
        by_date=True,
        by_group=False,
        demeaned=long_short,
        group_adjust=group_neutral)
    mean_quant_rateret_bydate = mean_quant_ret_bydate.apply(
        alphalens.utils.rate_of_return, axis=0,
        base_period=mean_quant_ret_bydate.columns[0])
    compstd_quant_daily = std_quant_daily.apply(
        alphalens.utils.std_conversion, axis=0,
        base_period=std_quant_daily.columns[0])

    alpha_beta = perf.factor_alpha_beta(factor_data,
                                        factor_returns,
                                        long_short,
                                        group_neutral)
    # Spread between the highest and lowest quantile, per date.
    mean_ret_spread_quant, std_spread_quant = \
        perf.compute_mean_returns_spread(
            mean_quant_rateret_bydate,
            factor_data['factor_quantile'].max(),
            factor_data['factor_quantile'].min(),
            std_err=compstd_quant_daily)

    plotting.plot_returns_table(alpha_beta,
                                mean_quant_rateret,
                                mean_ret_spread_quant)
    returns_table = _build_returns_table(alpha_beta, mean_quant_rateret,
                                         mean_ret_spread_quant)

    plotting.plot_quantile_statistics_table(factor_data)
    quantile_stats = _build_quantile_stats(factor_data)
    quantile_stats.to_csv('csv/{0}/quantile_stats.csv'.format(factor_name))

    plotting.plot_quantile_returns_bar(mean_quant_rateret,
                                       by_group=False,
                                       ylim_percentiles=None)
    save('plot_quantile_returns_bar.jpg')

    plotting.plot_quantile_returns_violin(mean_quant_rateret_bydate,
                                          ylim_percentiles=(1, 99))
    save('plot_quantile_returns_violin.jpg')

    for p in factor_returns:
        title = ('Factor Weighted '
                 + ('Group Neutral ' if group_neutral else '')
                 + ('Long/Short ' if long_short else '')
                 + "Portfolio Cumulative Return ({} Period)".format(p))

        plotting.plot_cumulative_returns(factor_returns[p],
                                         period=p,
                                         title=title)
        save('{0}_plot_cumulative_returns.jpg'.format(p))

        plotting.plot_cumulative_returns_by_quantile(mean_quant_ret_bydate[p],
                                                     period=p)
        save('{0}_plot_cumulative_returns_by_quantile.jpg'.format(p))

    plotting.plot_mean_quantile_returns_spread_time_series(
        mean_ret_spread_quant,
        std_err=std_spread_quant,
        bandwidth=0.5)
    save('plot_mean_quantile_returns_spread_time_series.jpg')

    # ---- information coefficient analysis ----------------------------------
    ic = perf.factor_information_coefficient(factor_data, group_neutral)
    ic_summary_table = _build_ic_summary_table(ic)

    plotting.plot_ic_ts(ic)
    save('plot_ic_ts.jpg')

    plotting.plot_ic_hist(ic)
    save('plot_ic_hist.jpg')

    plotting.plot_ic_qq(ic)
    save('plot_ic_qq.jpg')

    if by_group:
        mean_group_ic = perf.mean_information_coefficient(
            factor_data,
            group_adjust=group_neutral,
            by_group=True)
        plotting.plot_ic_by_group(mean_group_ic)
        save('plot_ic_by_group.jpg')
    else:
        mean_monthly_ic = perf.mean_information_coefficient(
            factor_data,
            group_adjust=group_neutral,
            by_group=False,
            by_time="M")
        plotting.plot_monthly_ic_heatmap(mean_monthly_ic)
        save('plot_monthly_ic_heatmap.jpg')

    # ---- turnover analysis -------------------------------------------------
    turnover_periods = utils.get_forward_returns_columns(factor_data.columns)
    quantile_factor = factor_data['factor_quantile']
    # For each forward-return period: one turnover column per quantile 1..max.
    quantile_turnover = {
        p: pd.concat([perf.quantile_turnover(quantile_factor, q, p)
                      for q in range(1, int(quantile_factor.max()) + 1)],
                     axis=1)
        for p in turnover_periods}
    autocorrelation = pd.concat(
        [perf.factor_rank_autocorrelation(factor_data, period)
         for period in turnover_periods], axis=1)

    plotting.plot_turnover_table(autocorrelation, quantile_turnover)
    turnover_table, auto_corr = _build_turnover_tables(quantile_turnover,
                                                       autocorrelation)

    for period in turnover_periods:
        plotting.plot_top_bottom_quantile_turnover(quantile_turnover[period],
                                                   period=period)
        save('{0}_plot_top_bottom_quantile_turnover.jpg'.format(period))

    for period in autocorrelation:
        plotting.plot_factor_rank_auto_correlation(autocorrelation[period],
                                                   period=period)
        save('{0}_plot_factor_rank_auto_correlation.jpg'.format(period))

    # Stack every summary table (rows) into one CSV.
    df = pd.concat([returns_table, ic_summary_table, turnover_table,
                    auto_corr], axis=0)
    df.to_csv('csv/{0}/all_data.csv'.format(factor_name))
    


def alpha_plot_by_group(factor_data, factor_name, long_short=True,
                            group_neutral=True,
                            by_group=True, close_figures=False):
    """Plot per-group mean IC and per-group quantile returns.

    Saves ``plot_ic_by_group.jpg`` and ``plot_quantile_returns_bar_by_group.jpg``
    under ``picture/<factor_name>/``. Does nothing when ``by_group`` is False.

    Parameters
    ----------
    factor_data : pd.DataFrame
        alphalens factor/forward-return frame; presumably must carry group
        labels for ``by_group=True`` to work (confirm against caller).
    factor_name : str
        Sub-directory of ``picture/`` the figures are written to.
    long_short : bool
        Passed to alphalens as ``demeaned``.
    group_neutral : bool
        Passed to alphalens as ``group_adjust``.
    by_group : bool
        When False, the function returns immediately without plotting.
    close_figures : bool, optional
        Close each figure right after saving it (default False keeps them
        open, e.g. for inline notebook display).
    """
    if not by_group:
        return

    # This function may run without alpha_plot_all having created the
    # directory first, and savefig does not create missing directories.
    os.makedirs('picture/{0}/'.format(factor_name), exist_ok=True)

    def save(file_name):
        plt.savefig('picture/{0}/{1}'.format(factor_name, file_name),
                    bbox_inches='tight')
        if close_figures:
            plt.close()

    mean_group_ic = perf.mean_information_coefficient(
        factor_data,
        group_adjust=group_neutral,
        by_group=True)
    plotting.plot_ic_by_group(mean_group_ic)
    save('plot_ic_by_group.jpg')

    # Standard errors are not needed for the bar chart.
    mean_return_quantile_group, _ = perf.mean_return_by_quantile(
        factor_data,
        by_date=False,
        by_group=True,
        demeaned=long_short,
        group_adjust=group_neutral)
    mean_quant_rateret_group = mean_return_quantile_group.apply(
        utils.rate_of_return, axis=0,
        base_period=mean_return_quantile_group.columns[0])

    plotting.plot_quantile_returns_bar(mean_quant_rateret_group,
                                       by_group=True,
                                       ylim_percentiles=(5, 95))
    # Use a distinct file name: alpha_plot_all already writes the pooled
    # (non-grouped) chart to 'plot_quantile_returns_bar.jpg', which this
    # per-group chart used to overwrite.
    save('plot_quantile_returns_bar_by_group.jpg')